import cv2
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
from utils import *
import os
#from scipy.sparse import csr_matrix
from scipy.sparse.linalg import lsqr
from scipy.sparse import csr_matrix
# Load the toy image as a grayscale float array scaled by 1/255.
# cv2.imread returns BGR; after the BGR->RGB conversion the grayscale step has
# to use the RGB code, otherwise the red and blue luma weights are swapped.
toy_img = cv2.cvtColor(cv2.imread('samples/toy_problem.png'), cv2.COLOR_BGR2RGB)
toy_img = cv2.cvtColor(toy_img, cv2.COLOR_RGB2GRAY).astype('double') / 255.0
plt.imshow(toy_img, cmap="gray")
toy_img.shape
def toy_reconstruct(toy_img):
    """
    Reconstruct a grayscale image from its gradients plus one pixel intensity.

    Denote the intensity of the source image at (x, y) as s(x, y) and the value
    to solve for as v(x, y). The least-squares objectives are:
      1. minimize (v(x+1,y) - v(x,y) - (s(x+1,y) - s(x,y)))^2
      2. minimize (v(x,y+1) - v(x,y) - (s(x,y+1) - s(x,y)))^2
    These are unchanged by adding any constant to v, so one more objective pins
    the top-left pixel:
      3. minimize (v(0,0) - s(0,0))^2

    The system is assembled as a sparse matrix; a dense matrix of the same
    shape needs (2 * pixels) * pixels floats, which is gigabytes even for a
    small image.

    :param toy_img: numpy.ndarray, 2-D array of intensities
    :return: numpy.ndarray of the same shape holding the reconstruction
    """
    toy_img = np.asarray(toy_img, dtype=float)
    rows, cols = toy_img.shape
    pixel_count = rows * cols
    # im2var[y, x] is the index of pixel (y, x) in the unknown vector.
    im2var = np.arange(pixel_count).reshape(rows, cols)

    # Objective 1: one equation per horizontally adjacent pair.
    h_plus = im2var[:, 1:].ravel()
    h_minus = im2var[:, :-1].ravel()
    h_rhs = (toy_img[:, 1:] - toy_img[:, :-1]).ravel()

    # Objective 2: one equation per vertically adjacent pair.
    v_plus = im2var[1:, :].ravel()
    v_minus = im2var[:-1, :].ravel()
    v_rhs = (toy_img[1:, :] - toy_img[:-1, :]).ravel()

    gradient_count = h_plus.size + v_plus.size
    gradient_ids = np.arange(gradient_count)
    anchor_id = gradient_count  # objective 3 is the last equation

    eq_rows = np.concatenate([gradient_ids, gradient_ids, [anchor_id]])
    eq_cols = np.concatenate(
        [h_plus, v_plus, h_minus, v_minus, [im2var[0, 0]]]
    )
    eq_vals = np.concatenate(
        [np.ones(gradient_count), -np.ones(gradient_count), [1.0]]
    )
    b = np.concatenate([h_rhs, v_rhs, [toy_img[0, 0]]])

    A = csr_matrix(
        (eq_vals, (eq_rows, eq_cols)),
        shape=(gradient_count + 1, pixel_count),
    )
    solution = lsqr(A, b)[0]
    return solution.reshape(rows, cols)
# Run the toy reconstruction on the image loaded above.
im_out = toy_reconstruct(toy_img)
# Bare expression: only has a visible effect as the last line of a notebook cell.
im_out
# Print the square root of the summed squared differences between the
# reconstruction and the source, but only when the result has a nonzero entry.
if im_out.any():
print("Error is: ", np.sqrt(((im_out - toy_img)**2).sum()))
#Images sanity check
# Show source and reconstruction side by side and save the figure to disk.
fig, axes = plt.subplots(1, 2)
axes[0].imshow(toy_img,cmap='gray')
axes[1].imshow(im_out,cmap='gray')
axes[0].title.set_text('Original Image')
axes[1].title.set_text('Gradiant Domain Image')
fig.set_size_inches(10, 10)
plt.savefig('toy_problem.png')
# Feel free to change image
# Background (target) image: read as BGR, converted to RGB, scaled by 1/255.
background_img = cv2.cvtColor(cv2.imread('samples/im2.JPG'), cv2.COLOR_BGR2RGB).astype('double') / 255.0
plt.figure()
plt.imshow(background_img)
# Feel free to change image
# Object (source) image, loaded the same way.
object_img = cv2.cvtColor(cv2.imread('samples/penguin-chick.jpeg'), cv2.COLOR_BGR2RGB).astype('double') / 255.0
import matplotlib.pyplot as plt
# Switch to the notebook backend before calling specify_mask.
# NOTE(review): specify_mask / get_mask / specify_bottom_center / align_source
# come from `utils`; presumably interactive selection helpers - confirm there.
%matplotlib notebook
mask_coords = specify_mask(object_img)
# The first two entries of the result are used as x and y coordinates.
xs = mask_coords[0]
ys = mask_coords[1]
%matplotlib inline
import matplotlib.pyplot as plt
plt.figure()
mask = get_mask(ys, xs, object_img)
%matplotlib notebook
import matplotlib.pyplot as plt
# bottom_center is also read as a module-level name by the blend functions,
# which index it as [0] for the column and [1] for the row of the paste position.
bottom_center = specify_bottom_center(background_img)
%matplotlib inline
import matplotlib.pyplot as plt
cropped_object, object_mask = align_source(object_img, mask, background_img, bottom_center)
def _poisson_window(image_shape, mask_shape, bottom_center, pad):
    """Return (top, bottom, left, right) of the blend window, clamped to the image."""
    row_start = int(bottom_center[1] - mask_shape[0])
    col_start = int(bottom_center[0] - mask_shape[1] / 2)
    # Clamp so a paste position near the border cannot yield a negative slice
    # start (which NumPy would interpret as counting from the far edge).
    top = max(row_start - pad, 0)
    bottom = min(row_start + mask_shape[0] + pad, image_shape[0])
    left = max(col_start - pad, 0)
    right = min(col_start + mask_shape[1] + pad, image_shape[1])
    if top >= bottom or left >= right:
        raise ValueError("blend window lies outside the background image")
    return top, bottom, left, right


def _poisson_system(source, target, inside):
    """Build the sparse system (A, b) for Poisson blending; b has one column per channel."""
    height, width = inside.shape
    # im2var[y, x] is the index of pixel (y, x) in the unknown vector.
    im2var = np.arange(height * width).reshape(height, width)
    everything = slice(None)
    # (pixel slice, neighbour slice, emit the coupled in-mask equation?)
    # Coupled equations are emitted for right/down only so that each in-mask
    # pair contributes exactly one equation.
    directions = (
        ((everything, slice(None, -1)), (everything, slice(1, None)), True),   # right
        ((slice(None, -1), everything), (slice(1, None), everything), True),   # down
        ((everything, slice(1, None)), (everything, slice(None, -1)), False),  # left
        ((slice(1, None), everything), (slice(None, -1), everything), False),  # up
    )

    eq_rows, eq_cols, eq_vals, rhs = [], [], [], []
    next_eq = 0
    for here, there, coupled in directions:
        pixel_in = inside[here]
        neighbour_in = inside[there]
        gradient = source[here] - source[there]
        if coupled:
            # Both pixels unknown: v(p) - v(q) = s(p) - s(q).
            both = pixel_in & neighbour_in
            ids = np.arange(next_eq, next_eq + np.count_nonzero(both))
            eq_rows += [ids, ids]
            eq_cols += [im2var[here][both], im2var[there][both]]
            eq_vals += [np.ones(ids.size), -np.ones(ids.size)]
            rhs.append(gradient[both])
            next_eq += ids.size
        # Neighbour is background: v(p) = s(p) - s(q) + t(q).
        edge = pixel_in & ~neighbour_in
        ids = np.arange(next_eq, next_eq + np.count_nonzero(edge))
        eq_rows.append(ids)
        eq_cols.append(im2var[here][edge])
        eq_vals.append(np.ones(ids.size))
        rhs.append(gradient[edge] + target[there][edge])
        next_eq += ids.size

    # Pixels outside the mask are pinned to the background: v(p) = t(p).
    outside = ~inside
    ids = np.arange(next_eq, next_eq + np.count_nonzero(outside))
    eq_rows.append(ids)
    eq_cols.append(im2var[outside])
    eq_vals.append(np.ones(ids.size))
    rhs.append(target[outside])
    next_eq += ids.size

    A = csr_matrix(
        (np.concatenate(eq_vals), (np.concatenate(eq_rows), np.concatenate(eq_cols))),
        shape=(next_eq, height * width),
    )
    return A, np.concatenate(rhs, axis=0)


def poisson_blend(cropped_object, object_mask, background_img, mask,
                  bottom_center=None, pad=40):
    """
    Blend the masked object into the background by solving a Poisson system.

    Inside the mask the result follows the gradients of ``cropped_object``;
    where a masked pixel touches an unmasked neighbour the background value is
    used as the boundary condition. Only a window around the paste position
    (``pad`` pixels larger than ``mask`` on every side, clamped to the image)
    is solved. Masked pixels on the window edge simply get no equation towards
    the missing neighbour. ``background_img`` is not modified.

    :param cropped_object: numpy.ndarray One you get from align_source;
        indexed as [row, col, channel]
    :param object_mask: numpy.ndarray One you get from align_source; truthy
        where the object is pasted
    :param background_img: numpy.ndarray indexed as [row, col, channel]
    :param mask: numpy.ndarray; only its shape is used, to size the window
    :param bottom_center: paste position, [0] is the column and [1] the row;
        when None the module-level ``bottom_center`` is used
    :param pad: margin in pixels added around the window
    :return: (output, output_mask): a copy of the background with the blended
        window written in, and the blended window alone (three channels)
    :raises ValueError: if the window does not overlap the background
    :raises NameError: if no ``bottom_center`` is given and none exists at
        module level
    """
    if bottom_center is None:
        # Backward compatibility: earlier callers relied on a notebook global.
        try:
            bottom_center = globals()["bottom_center"]
        except KeyError:
            raise NameError("name 'bottom_center' is not defined") from None

    top, bottom, left, right = _poisson_window(
        background_img.shape, mask.shape, bottom_center, pad
    )
    # Copy so the caller's background is left untouched.
    output = background_img.copy()
    target = background_img[top:bottom, left:right]
    source = cropped_object[top:bottom, left:right]
    inside = np.asarray(object_mask[top:bottom, left:right]).astype(bool)

    # The matrix is identical for every channel; only the right-hand side differs.
    A, b = _poisson_system(source, target, inside)

    output_mask = np.zeros(shape=inside.shape + (3,))
    for z in range(3):
        solved = lsqr(A, b[:, z])[0]
        channel = np.clip(solved, a_min=0, a_max=1).reshape(inside.shape)
        output[top:bottom, left:right, z] = channel
        output_mask[:, :, z] = channel
    return output, output_mask
# Poisson-blend the aligned object into the background, then show the full
# result next to the blended window and save the figure.
output, output_mask = poisson_blend(cropped_object, object_mask, background_img,mask)
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(10, 10)
axes[0].imshow(output)
axes[1].imshow(output_mask)
plt.savefig('penguin_mountain.jpg')
# Second Poisson example: a new background/object pair goes through the same
# select-mask, pick-position, align, blend sequence as the first one.
background_img = cv2.cvtColor(cv2.imread('samples/feild.JPG'), cv2.COLOR_BGR2RGB).astype('double') / 255.0
plt.figure()
plt.imshow(background_img)
# Feel free to change image
object_img = cv2.cvtColor(cv2.imread('samples/dragon_large.jpg'), cv2.COLOR_BGR2RGB).astype('double') / 255.0
import matplotlib.pyplot as plt
# The backend is switched to notebook before each selection call and back to
# inline afterwards.
%matplotlib notebook
mask_coords = specify_mask(object_img)
xs = mask_coords[0]
ys = mask_coords[1]
%matplotlib inline
import matplotlib.pyplot as plt
plt.figure()
mask = get_mask(ys, xs, object_img)
%matplotlib notebook
import matplotlib.pyplot as plt
# Rebinds the module-level bottom_center that the blend functions read.
bottom_center = specify_bottom_center(background_img)
%matplotlib inline
import matplotlib.pyplot as plt
cropped_object, object_mask = align_source(object_img, mask, background_img, bottom_center)
output, output_mask = poisson_blend(cropped_object, object_mask, background_img,mask)
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(30, 30)
axes[0].imshow(output)
axes[1].imshow(output_mask)
plt.savefig('dragon_feild.jpg')
# Third Poisson example, same sequence as the previous two.
background_img = cv2.cvtColor(cv2.imread('samples/barkstall.JPG'), cv2.COLOR_BGR2RGB).astype('double') / 255.0
plt.figure()
plt.imshow(background_img)
# Feel free to change image
object_img = cv2.cvtColor(cv2.imread('samples/lion.jpg'), cv2.COLOR_BGR2RGB).astype('double') / 255.0
import matplotlib.pyplot as plt
%matplotlib notebook
mask_coords = specify_mask(object_img)
xs = mask_coords[0]
ys = mask_coords[1]
%matplotlib inline
import matplotlib.pyplot as plt
plt.figure()
mask = get_mask(ys, xs, object_img)
%matplotlib notebook
import matplotlib.pyplot as plt
# Rebinds the module-level bottom_center that the blend functions read.
bottom_center = specify_bottom_center(background_img)
%matplotlib inline
import matplotlib.pyplot as plt
cropped_object, object_mask = align_source(object_img, mask, background_img, bottom_center)
output, output_mask = poisson_blend(cropped_object, object_mask, background_img,mask)
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(30, 30)
axes[0].imshow(output)
axes[1].imshow(output_mask)
# NOTE(review): the background loaded above is 'barkstall.JPG', not the field
# image - the output file name may be a leftover; confirm.
plt.savefig('lion_feild.jpg')
def _mix_window(image_shape, mask_shape, bottom_center, pad):
    """Return (top, bottom, left, right) of the blend window, clamped to the image."""
    row_start = int(bottom_center[1] - mask_shape[0])
    col_start = int(bottom_center[0] - mask_shape[1] / 2)
    # Clamp so a paste position near the border cannot yield a negative slice
    # start (which NumPy would interpret as counting from the far edge).
    top = max(row_start - pad, 0)
    bottom = min(row_start + mask_shape[0] + pad, image_shape[0])
    left = max(col_start - pad, 0)
    right = min(col_start + mask_shape[1] + pad, image_shape[1])
    if top >= bottom or left >= right:
        raise ValueError("blend window lies outside the background image")
    return top, bottom, left, right


def _mix_system(source, target, inside):
    """Build the sparse mixed-gradient system (A, b); b has one column per channel."""
    height, width = inside.shape
    # im2var[y, x] is the index of pixel (y, x) in the unknown vector.
    im2var = np.arange(height * width).reshape(height, width)
    everything = slice(None)
    # (pixel slice, neighbour slice, emit the coupled in-mask equation?)
    # Coupled equations are emitted for right/down only so that each in-mask
    # pair contributes exactly one equation.
    directions = (
        ((everything, slice(None, -1)), (everything, slice(1, None)), True),   # right
        ((slice(None, -1), everything), (slice(1, None), everything), True),   # down
        ((everything, slice(1, None)), (everything, slice(None, -1)), False),  # left
        ((slice(1, None), everything), (slice(None, -1), everything), False),  # up
    )

    eq_rows, eq_cols, eq_vals, rhs = [], [], [], []
    next_eq = 0
    for here, there, coupled in directions:
        pixel_in = inside[here]
        neighbour_in = inside[there]
        source_gradient = source[here] - source[there]
        background_gradient = target[here] - target[there]
        # Mixed gradients: keep whichever gradient is strictly larger in
        # magnitude, preferring the source on ties.
        gradient = np.where(
            np.abs(background_gradient) > np.abs(source_gradient),
            background_gradient,
            source_gradient,
        )
        if coupled:
            # Both pixels unknown: v(p) - v(q) = d(p, q).
            both = pixel_in & neighbour_in
            ids = np.arange(next_eq, next_eq + np.count_nonzero(both))
            eq_rows += [ids, ids]
            eq_cols += [im2var[here][both], im2var[there][both]]
            eq_vals += [np.ones(ids.size), -np.ones(ids.size)]
            rhs.append(gradient[both])
            next_eq += ids.size
        # Neighbour is background: v(p) = d(p, q) + t(q).
        edge = pixel_in & ~neighbour_in
        ids = np.arange(next_eq, next_eq + np.count_nonzero(edge))
        eq_rows.append(ids)
        eq_cols.append(im2var[here][edge])
        eq_vals.append(np.ones(ids.size))
        rhs.append(gradient[edge] + target[there][edge])
        next_eq += ids.size

    # Pixels outside the mask are pinned to the background: v(p) = t(p).
    outside = ~inside
    ids = np.arange(next_eq, next_eq + np.count_nonzero(outside))
    eq_rows.append(ids)
    eq_cols.append(im2var[outside])
    eq_vals.append(np.ones(ids.size))
    rhs.append(target[outside])
    next_eq += ids.size

    A = csr_matrix(
        (np.concatenate(eq_vals), (np.concatenate(eq_rows), np.concatenate(eq_cols))),
        shape=(next_eq, height * width),
    )
    return A, np.concatenate(rhs, axis=0)


def mix_blend(cropped_object, object_mask, background_img, mask,
              bottom_center=None, pad=40):
    """
    Blend the masked object into the background using mixed gradients.

    Works like Poisson blending, except that for every pixel pair the guiding
    gradient is taken from the background when its magnitude is strictly
    larger than the object's, and from the object otherwise. Only a window
    around the paste position (``pad`` pixels larger than ``mask`` on every
    side, clamped to the image) is solved. Masked pixels on the window edge
    simply get no equation towards the missing neighbour. ``background_img``
    is not modified.

    :param cropped_object: numpy.ndarray One you get from align_source;
        indexed as [row, col, channel]
    :param object_mask: numpy.ndarray One you get from align_source; truthy
        where the object is pasted
    :param background_img: numpy.ndarray indexed as [row, col, channel]
    :param mask: numpy.ndarray; only its shape is used, to size the window
    :param bottom_center: paste position, [0] is the column and [1] the row;
        when None the module-level ``bottom_center`` is used
    :param pad: margin in pixels added around the window
    :return: (output, output_mask): a copy of the background with the blended
        window written in, and the blended window alone (three channels)
    :raises ValueError: if the window does not overlap the background
    :raises NameError: if no ``bottom_center`` is given and none exists at
        module level
    """
    if bottom_center is None:
        # Backward compatibility: earlier callers relied on a notebook global.
        try:
            bottom_center = globals()["bottom_center"]
        except KeyError:
            raise NameError("name 'bottom_center' is not defined") from None

    top, bottom, left, right = _mix_window(
        background_img.shape, mask.shape, bottom_center, pad
    )
    # Copy so the caller's background is left untouched.
    output = background_img.copy()
    target = background_img[top:bottom, left:right]
    source = cropped_object[top:bottom, left:right]
    inside = np.asarray(object_mask[top:bottom, left:right]).astype(bool)

    # The matrix is identical for every channel; only the right-hand side differs.
    A, b = _mix_system(source, target, inside)

    output_mask = np.zeros(shape=inside.shape + (3,))
    for z in range(3):
        solved = lsqr(A, b[:, z])[0]
        channel = np.clip(solved, a_min=0, a_max=1).reshape(inside.shape)
        output[top:bottom, left:right, z] = channel
        output_mask[:, :, z] = channel
    return output, output_mask
# Feel free to change image
# Mixed-gradient example: reload the first background/object pair and repeat
# the select-mask, pick-position, align sequence before calling mix_blend.
background_img = cv2.cvtColor(cv2.imread('samples/im2.JPG'), cv2.COLOR_BGR2RGB).astype('double') / 255.0
plt.figure()
plt.imshow(background_img)
# Feel free to change image
object_img = cv2.cvtColor(cv2.imread('samples/penguin-chick.jpeg'), cv2.COLOR_BGR2RGB).astype('double') / 255.0
import matplotlib.pyplot as plt
%matplotlib notebook
mask_coords = specify_mask(object_img)
xs = mask_coords[0]
ys = mask_coords[1]
%matplotlib inline
import matplotlib.pyplot as plt
plt.figure()
mask = get_mask(ys, xs, object_img)
%matplotlib notebook
import matplotlib.pyplot as plt
# Rebinds the module-level bottom_center that the blend functions read.
bottom_center = specify_bottom_center(background_img)
%matplotlib inline
import matplotlib.pyplot as plt
cropped_object, object_mask = align_source(object_img, mask, background_img, bottom_center)
output, output_mask = mix_blend(cropped_object, object_mask, background_img, mask)
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(10, 10)
axes[0].imshow(output)
axes[1].imshow(output_mask)
# NOTE(review): same file name as the Poisson-blend figure saved earlier in
# this file, so this call overwrites it - confirm that is intended.
plt.savefig('penguin_mountain.jpg')
# Feel free to change image
# Second mixed-gradient example, same sequence with a new image pair.
background_img = cv2.cvtColor(cv2.imread('samples/pool.JPG'), cv2.COLOR_BGR2RGB).astype('double') / 255.0
plt.figure()
plt.imshow(background_img)
# Feel free to change image
object_img = cv2.cvtColor(cv2.imread('samples/alligator_small.jpg'), cv2.COLOR_BGR2RGB).astype('double') / 255.0
import matplotlib.pyplot as plt
%matplotlib notebook
mask_coords = specify_mask(object_img)
xs = mask_coords[0]
ys = mask_coords[1]
%matplotlib inline
import matplotlib.pyplot as plt
plt.figure()
mask = get_mask(ys, xs, object_img)
%matplotlib notebook
import matplotlib.pyplot as plt
# Rebinds the module-level bottom_center that the blend functions read.
bottom_center = specify_bottom_center(background_img)
%matplotlib inline
import matplotlib.pyplot as plt
cropped_object, object_mask = align_source(object_img, mask, background_img, bottom_center)
output, output_mask = mix_blend(cropped_object, object_mask, background_img, mask)
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(30, 30)
axes[0].imshow(output)
axes[1].imshow(output_mask)
plt.savefig('pool_aligator.jpg')
# NOTE(review): nothing is reassigned between the call above and this one, so
# this blends the pool/alligator inputs again while saving under a
# dragon/field file name - looks like a stale notebook cell; confirm.
output, output_mask = mix_blend(cropped_object, object_mask, background_img,mask)
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(30, 30)
axes[0].imshow(output)
axes[1].imshow(output_mask)
plt.savefig('dragon_feild_mix_blend.jpg')
# Preview both colour-blindness plates: first in RGB, then through OpenCV's
# plain grayscale conversion. Each preview gets its own 10x10 inch figure.
for conversion, colormap in ((cv2.COLOR_BGR2RGB, None), (cv2.COLOR_BGR2GRAY, 'gray')):
    for plate in ('samples/colorBlind4.png', 'samples/colorBlind8.png'):
        img = cv2.cvtColor(cv2.imread(plate), conversion).astype('double') / 255.0
        fig, axes = plt.subplots(1, 1)
        fig.set_size_inches(10, 10)
        axes.imshow(img, cmap=colormap)
def color2gray(img):
    """
    Convert a colour image to grayscale in the gradient domain.

    For every pixel the largest of its three channel values is taken. The
    output is solved for so that its horizontal and vertical differences match
    the differences of that per-pixel maximum, with the top-left pixel pinned
    to the mean of the top-left pixel's three channels. The result is clipped
    to [0, 1].

    :param img: path of an image file (str), or an already loaded
        numpy.ndarray indexed as [row, col, channel] with values in [0, 1]
    :return: numpy.ndarray of shape (rows, cols)
    """
    if isinstance(img, str):
        color_image = cv2.cvtColor(cv2.imread(img), cv2.COLOR_BGR2RGB).astype('double') / 255.0
    else:
        color_image = np.asarray(img, dtype=float)

    # Per-pixel maximum over the three colour channels.
    peak = color_image[:, :, :3].max(axis=2)
    rows, cols = peak.shape
    pixel_count = rows * cols
    # im2var[y, x] is the index of pixel (y, x) in the unknown vector.
    im2var = np.arange(pixel_count).reshape(rows, cols)

    # One equation per horizontally adjacent pair ...
    h_plus = im2var[:, 1:].ravel()
    h_minus = im2var[:, :-1].ravel()
    h_rhs = (peak[:, 1:] - peak[:, :-1]).ravel()
    # ... and one per vertically adjacent pair.
    v_plus = im2var[1:, :].ravel()
    v_minus = im2var[:-1, :].ravel()
    v_rhs = (peak[1:, :] - peak[:-1, :]).ravel()

    gradient_count = h_plus.size + v_plus.size
    gradient_ids = np.arange(gradient_count)
    # Last equation pins pixel (0, 0) so the constant offset is determined.
    anchor_value = color_image[0, 0, :3].sum() / 3

    eq_rows = np.concatenate([gradient_ids, gradient_ids, [gradient_count]])
    eq_cols = np.concatenate(
        [h_plus, v_plus, h_minus, v_minus, [im2var[0, 0]]]
    )
    eq_vals = np.concatenate(
        [np.ones(gradient_count), -np.ones(gradient_count), [1.0]]
    )
    b = np.concatenate([h_rhs, v_rhs, [anchor_value]])

    A = csr_matrix(
        (eq_vals, (eq_rows, eq_cols)),
        shape=(gradient_count + 1, pixel_count),
    )
    solution = lsqr(A, b)[0]
    return np.clip(solution, a_min=0, a_max=1).reshape(rows, cols)
# Run the gradient-domain conversion on both colour-blindness plates.
# The colormap keyword is lower-case `cmap`; current Matplotlib rejects `CMAP`.
img = 'samples/colorBlind4.png'
gray_image = color2gray(img)
fig, axes = plt.subplots(1, 1)
fig.set_size_inches(10, 10)
axes.imshow(gray_image, cmap='gray')
img_2 = 'samples/colorBlind8.png'
gray_image = color2gray(img_2)
fig, axes = plt.subplots(1, 1)
fig.set_size_inches(10, 10)
axes.imshow(gray_image, cmap='gray')
def _pyr_up_like(image, reference):
    """Upsample ``image`` with cv2.pyrUp to exactly the size of ``reference``."""
    # Without dstsize, pyrUp doubles the size, which is one pixel too large
    # whenever the finer level had an odd dimension.
    return cv2.pyrUp(image, dstsize=(reference.shape[1], reference.shape[0]))


def laplacian_blend(img1, img2, pyramid_height):
    """
    Blend the left half of ``img1`` with the right half of ``img2``.

    Both images are decomposed into Laplacian pyramids with ``pyramid_height``
    levels, each level is joined at its middle column, and the joined pyramid
    is collapsed back into one image. With a height of 0 or 1 there is a single
    level, so the result equals the plain side-by-side cut.

    :param img1: numpy.ndarray providing the left half
    :param img2: numpy.ndarray providing the right half
    :param pyramid_height: number of pyramid levels
    :return: (output, img1_img2): the blended image clipped to [0, 1], and the
        unblended side-by-side cut for comparison
    """
    half = img1.shape[1] // 2
    img1_img2 = np.hstack((img1[:, :half], img2[:, half:]))

    # Gaussian pyramids, finest level first.
    gauss1, gauss2 = [img1], [img2]
    for _ in range(pyramid_height - 1):
        gauss1.append(cv2.pyrDown(gauss1[-1]))
        gauss2.append(cv2.pyrDown(gauss2[-1]))

    # Laplacian pyramids, coarsest level first; the coarsest entry is the
    # Gaussian residual itself.
    lap1, lap2 = [gauss1[-1]], [gauss2[-1]]
    for level in range(len(gauss1) - 1, 0, -1):
        lap1.append(gauss1[level - 1] - _pyr_up_like(gauss1[level], gauss1[level - 1]))
        lap2.append(gauss2[level - 1] - _pyr_up_like(gauss2[level], gauss2[level - 1]))

    # Join every level at its own middle column.
    joined = []
    for band1, band2 in zip(lap1, lap2):
        middle = band1.shape[1] // 2
        joined.append(np.hstack((band1[:, :middle], band2[:, middle:])))

    # Collapse from coarse to fine.
    output = joined[0]
    for band in joined[1:]:
        output = _pyr_up_like(output, band) + band
    output = np.clip(output, a_min=0, a_max=1)
    return output, img1_img2
# Load the two portraits as RGB floats scaled by 1/255.
tarik, amar = (
    cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB).astype('double') / 255.0
    for path in ('samples/tarik.jpg', 'samples/amar.jpg')
)
# Show each input in its own 10x10 inch figure.
for portrait in (tarik, amar):
    fig, axes = plt.subplots(1, 1)
    fig.set_size_inches(10, 10)
    axes.imshow(portrait)
# Blend with a five-level pyramid, then show the plain cut followed by the blend.
output, non_blend_output = laplacian_blend(tarik, amar,5)
for result in (non_blend_output, output):
    fig, axes = plt.subplots(1, 1)
    fig.set_size_inches(10, 10)
    axes.imshow(result)